import torch
from torch import nn
import torch.nn.functional as F
from torchvision.models import resnet18, ResNet18_Weights, efficientnet_b0, efficientnet_v2_s, EfficientNet_V2_S_Weights,\
	EfficientNet_B0_Weights, EfficientNet_B1_Weights
from cvx2.wrapper import SplitImageClassifyModelWrapper
from cvx2.utils import get_pretrained


# def get_pretrained(num_classes):
# 	pretrained = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)
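# 	# pretrained.classifier[1].out_features = num_classes  # printing out_features shows the new value (6), but the model output is still 1000-wide: changing the attribute does not resize the layer's weight matrix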
# 	pretrained.classifier[1] = nn.Linear(in_features=pretrained.classifier[1].in_features, out_features=num_classes)
	
# 	pretrained.features.requires_grad_(False)
	
# 	# print(len(list(pretrained.parameters())))
	
# 	# Freeze the entire network
# 	# for param in pretrained.parameters():
# 	# 	param.requires_grad_(False)
# 		# print(name, param.numel())
	
# 	# for name, param in pretrained.named_parameters():
# 	# 	if 'classifier' in name:
# 	# 		break
# 	# 	param.requires_grad_(False)
	
# 	return pretrained


if __name__ == '__main__':
	# Script entry point: build a 2-class classifier from a torchvision ResNet-18
	# and train it through the cvx2 wrapper.
	#
	# The same weights enum member is used for the backbone and for the
	# preprocessing transforms, so the inputs match what those weights expect.
	weights = ResNet18_Weights.DEFAULT

	# Presumably get_pretrained(ctor, weights, num_classes) loads the pretrained
	# backbone and swaps its head for a num_classes-way layer. It lives in
	# cvx2.utils, which is not shown here — TODO confirm.
	pretrained = get_pretrained(resnet18, weights, 2)
	wrapper = SplitImageClassifyModelWrapper(pretrained)

	# NOTE(review): T_max looks like a cosine-annealing period, and it is larger
	# than `epochs` here. Confirm against SplitImageClassifyModelWrapper.train.
	wrapper.train(
		data='/Users/summy/data/AI-face',
		transform=weights.transforms(),
		# train_transform=train_transform,  # optional separate train-time augmentation
		batch_size=32,
		epochs=2,
		lr=0.001,
		T_max=3,
	)
